# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layer-wise Adaptive Moments (LAMB) optimizer.

See paper [Large Batch Optimization for Deep Learning: Training BERT in
76 minutes](https://arxiv.org/abs/1904.00962).
"""
import re
from typing import Optional, Union, Callable, List

import numpy as np
import tensorflow as tf, tf_keras

# Type alias for hyperparameters that accept either a Python/NumPy float
# scalar or a TensorFlow tensor (used in the LAMB constructor signature).
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]


@tf_keras.utils.register_keras_serializable(package="Addons")
class LAMB(tf_keras.optimizers.legacy.Optimizer):
  """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).

  See paper [Large Batch Optimization for Deep Learning: Training BERT
  in 76 minutes](https://arxiv.org/abs/1904.00962).

  Each step computes an Adam-style bias-corrected update
  `m_hat / (sqrt(v_hat) + epsilon)`. It optionally adds decoupled weight decay
  (`weight_decay_rate * var`). It then rescales the result by the trust ratio
  `||var|| / ||update||` before scaling by the learning rate. Both the weight
  decay and the trust-ratio rescaling can be disabled per variable via regex
  exclusion lists, which are matched against the variable name.
  """

  def __init__(
      self,
      learning_rate: Union[FloatTensorLike, Callable] = 0.001,
      beta_1: FloatTensorLike = 0.9,
      beta_2: FloatTensorLike = 0.999,
      epsilon: FloatTensorLike = 1e-6,
      weight_decay_rate: FloatTensorLike = 0.0,
      exclude_from_weight_decay: Optional[List[str]] = None,
      exclude_from_layer_adaptation: Optional[List[str]] = None,
      name: str = "LAMB",
      **kwargs,
  ):
    """Construct a new LAMB optimizer.

    Args:
        learning_rate: A `Tensor`, a floating point value, or a
          `tf_keras.optimizers.schedules.LearningRateSchedule`. The learning
          rate.
        beta_1: A `float` value or a constant `float` tensor. The exponential
          decay rate for the 1st moment estimates.
        beta_2: A `float` value or a constant `float` tensor. The exponential
          decay rate for the 2nd moment estimates.
        epsilon: A small constant for numerical stability. A falsy value
          (`None` or `0`) falls back to `tf_keras.backend.epsilon()`.
        weight_decay_rate: weight decay rate.
        exclude_from_weight_decay: List of regex patterns of variables excluded
          from weight decay. Variables whose name contain a substring matching
          the pattern will be excluded.
        exclude_from_layer_adaptation: List of regex patterns of variables
          excluded from layer adaptation. Variables whose name contain a
          substring matching the pattern will be excluded. If `None` or empty,
          `exclude_from_weight_decay` is used instead.
        name: Optional name for the operations created when applying gradients.
          Defaults to "LAMB".
        **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
          `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is
          clip gradients by value, `decay` is included for backward
          compatibility to allow time inverse decay of learning rate. `lr` is
          included for backward compatibility, recommended to use
          `learning_rate` instead.
    """
    super().__init__(name, **kwargs)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters.
    self._set_hyper("weight_decay_rate", weight_decay_rate)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    # This is learning rate decay for using keras learning rate schedule.
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("beta_1", beta_1)
    self._set_hyper("beta_2", beta_2)
    # `tf.backend_config` is not part of the public TensorFlow API; the Keras
    # backend is the supported source of the default fuzz factor.
    self.epsilon = epsilon or tf_keras.backend.epsilon()
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation falls back to exclude_from_weight_decay
    # when the arg is None or an empty list (truthiness check).
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def _create_slots(self, var_list):
    """Creates the first ("m") and second ("v") moment slots for each var."""
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, "m")
    for var in var_list:
      self.add_slot(var, "v")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Caches per-(device, dtype) coefficients read by the apply methods.

    `lr_t` is read by the apply methods but not set here; it is expected to be
    populated by the base-class `_prepare_local` call.
    """
    super()._prepare_local(var_device, var_dtype, apply_state)

    # Steps are 1-based for bias correction.
    local_step = tf.cast(self.iterations + 1, var_dtype)
    beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
    beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
    weight_decay_rate = tf.identity(
        self._get_hyper("weight_decay_rate", var_dtype)
    )
    beta_1_power = tf.pow(beta_1_t, local_step)
    beta_2_power = tf.pow(beta_2_t, local_step)
    apply_state[(var_device, var_dtype)].update(
        dict(
            weight_decay_rate=weight_decay_rate,
            epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t,
        )
    )

  def _scaled_update(self, var, m_t, v_t, coefficients):
    """Returns `ratio * lr_t * update`, the amount to subtract from `var`.

    Shared by the dense and sparse apply paths once their moment slots have
    been updated.

    Args:
      var: The variable being optimized. It is read for weight decay and for
        the trust-ratio norm.
      m_t: Updated first-moment value for `var`.
      v_t: Updated second-moment value for `var`.
      coefficients: Dict of cached tensors produced by `_prepare_local`.

    Returns:
      The tensor step to subtract from `var`.
    """
    m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
    v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])

    v_sqrt = tf.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + coefficients["epsilon"])

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      # Decoupled weight decay: added to the update, not to the gradient.
      update += coefficients["weight_decay_rate"] * var

    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = tf.norm(var, ord=2)
      g_norm = tf.norm(update, ord=2)
      # Trust ratio ||w|| / ||update||; fall back to 1.0 when either norm is
      # zero to avoid a zero or infinite step.
      ratio = tf.where(
          tf.greater(w_norm, 0),
          tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0,
      )

    return ratio * coefficients["lr_t"] * update

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies one dense LAMB step to `var` and returns the assign op."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
    m_t = m.assign(m_t, use_locking=self._use_locking)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
    v_t = v.assign(v_t, use_locking=self._use_locking)

    var_update = var - self._scaled_update(var, m_t, v_t, coefficients)
    return var.assign(var_update, use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies one LAMB step for an `IndexedSlices` gradient.

    Moment decay is applied to the whole slot. The scaled gradient terms are
    then scatter-added at `indices`, and the resulting update is applied to all
    of `var`.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = (apply_state or {}).get(
        (var_device, var_dtype)
    ) or self._fallback_apply_state(var_device, var_dtype)

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
    m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
    with tf.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
    v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
    with tf.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    var_update = var.assign_sub(
        self._scaled_update(var, m_t, v_t, coefficients),
        use_locking=self._use_locking,
    )
    return tf.group(*[var_update, m_t, v_t])

  def get_config(self):
    """Returns the serializable config, including the exclusion lists."""
    config = super().get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "weight_decay_rate": self._serialize_hyperparameter(
            "weight_decay_rate"
        ),
        "decay": self._serialize_hyperparameter("decay"),
        "beta_1": self._serialize_hyperparameter("beta_1"),
        "beta_2": self._serialize_hyperparameter("beta_2"),
        "epsilon": self.epsilon,
        # Without these, an optimizer rebuilt via from_config would decay and
        # trust-ratio-scale every variable.
        "exclude_from_weight_decay": self.exclude_from_weight_decay,
        "exclude_from_layer_adaptation": self.exclude_from_layer_adaptation,
    })
    return config

  @staticmethod
  def _matches_any(patterns, param_name):
    """Returns True if any regex in `patterns` is found in `param_name`."""
    return any(re.search(r, param_name) is not None for r in patterns or ())

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    return not self._matches_any(self.exclude_from_weight_decay, param_name)

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    return not self._matches_any(
        self.exclude_from_layer_adaptation, param_name
    )

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name (strips a `:<n>` suffix)."""
    m = re.match(r"^(.*):\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name